Tidymodels workflow
penguins_df <- penguins %>%
filter(!is.na(sex)) %>% # remoce NA
select(-year, -island) # remove these variables as they are not useful in this case
glimpse(penguins_df)
Rows: 333
Columns: 6
$ species <fct> Adelie, Adelie, Adelie, Adelie, Adelie, Ad…
$ bill_length_mm <dbl> 39.1, 39.5, 40.3, 36.7, 39.3, 38.9, 39.2, …
$ bill_depth_mm <dbl> 18.7, 17.4, 18.0, 19.3, 20.6, 17.8, 19.6, …
$ flipper_length_mm <int> 181, 186, 195, 193, 190, 181, 195, 182, 19…
$ body_mass_g <int> 3750, 3800, 3250, 3450, 3650, 3625, 4675, …
$ sex <fct> male, female, female, female, male, female…
penguins_df %>%
count(sex)
# A tibble: 2 × 2
sex n
<fct> <int>
1 female 165
2 male 168
set.seed(20211124)
penguins_split <- initial_split(penguins)
penguins_train <- training(penguins_split)
penguins_test <- testing(penguins_split)
set.seed(2021112402)
penguins_cv <- vfold_cv(penguins_train, n = 5)
svm_model <-
svm_rbf(cost = tune(), rbf_sigma = tune()) %>%
set_mode("classification") %>%
set_engine("kernlab")
svm_recipe <- recipe(sex ~ ., data = penguins_df)
svm_workflow <-
workflow() %>%
add_model(svm_model) %>%
add_recipe(svm_recipe)
Parameters need to be tuned using tune_grid function
cost()
Cost (quantitative)
Transformer: log-2
Range (transformed scale): [-10, 5]
rbf_sigma()
Radial Basis Function sigma (quantitative)
Transformer: log-10
Range (transformed scale): [-10, 0]
# control object that specifies different aspects of the search
ctrl <- control_grid(verbose = FALSE, save_pred = TRUE)
# set metric
roc_res <- metric_set(roc_auc)
set.seed(2021112405)
# tuning results
svm_results <-
svm_model %>%
tune_grid(
svm_recipe,
resamples = penguins_cv,
metrics = roc_res,
control = ctrl
)
svm_results
# Tuning results
# 10-fold cross-validation
# A tibble: 10 × 5
splits id .metrics .notes .predictions
<list> <chr> <list> <list> <list>
1 <split [232/26]> Fold01 <tibble [0 × 6]> <tibble [… <tibble [0 × …
2 <split [232/26]> Fold02 <tibble [10 × 6]> <tibble [… <tibble [260 …
3 <split [232/26]> Fold03 <tibble [10 × 6]> <tibble [… <tibble [260 …
4 <split [232/26]> Fold04 <tibble [10 × 6]> <tibble [… <tibble [260 …
5 <split [232/26]> Fold05 <tibble [0 × 6]> <tibble [… <tibble [0 × …
6 <split [232/26]> Fold06 <tibble [10 × 6]> <tibble [… <tibble [260 …
7 <split [232/26]> Fold07 <tibble [10 × 6]> <tibble [… <tibble [260 …
8 <split [232/26]> Fold08 <tibble [10 × 6]> <tibble [… <tibble [260 …
9 <split [233/25]> Fold09 <tibble [10 × 6]> <tibble [… <tibble [250 …
10 <split [233/25]> Fold10 <tibble [10 × 6]> <tibble [… <tibble [250 …
collect_metrics(svm_results)
# A tibble: 10 × 8
cost rbf_sigma .metric .estimator mean n std_err .config
<dbl> <dbl> <chr> <chr> <dbl> <int> <dbl> <chr>
1 0.0382 1.90e- 6 roc_auc binary 0.793 8 0.0954 Preproce…
2 0.00799 1.08e- 5 roc_auc binary 0.793 8 0.0954 Preproce…
3 1.33 3.25e- 1 roc_auc binary 0.973 8 0.0112 Preproce…
4 0.00178 5.02e-10 roc_auc binary 0.764 8 0.0916 Preproce…
5 4.37 5.03e- 2 roc_auc binary 0.968 8 0.0107 Preproce…
6 15.0 4.59e- 4 roc_auc binary 0.960 8 0.00805 Preproce…
7 3.04 5.92e- 7 roc_auc binary 0.793 8 0.0954 Preproce…
8 0.00700 2.83e- 8 roc_auc binary 0.793 8 0.0954 Preproce…
9 0.101 2.14e- 9 roc_auc binary 0.785 8 0.0941 Preproce…
10 0.217 5.30e- 3 roc_auc binary 0.883 8 0.0195 Preproce…
# ggplot for accuracy vs cost
collect_metrics(svm_results) %>%
ggplot(aes(cost, mean)) +
geom_point() +
geom_line() +
labs(y = "mean accuracy for repeated cv",
title = "Average performance profile for SVM classification model") +
scale_x_continuous(n.breaks = 20, expand = c(0,0)) +
scale_y_continuous(labels = scales::number_format(), limits = c(0.5, 1),
expand = c(0,0)) +
theme_classic()
show_best(svm_results,
metric = "roc_auc")
# A tibble: 5 × 8
cost rbf_sigma .metric .estimator mean n std_err .config
<dbl> <dbl> <chr> <chr> <dbl> <int> <dbl> <chr>
1 1.33 0.325 roc_auc binary 0.973 8 0.0112 Preproces…
2 4.37 0.0503 roc_auc binary 0.968 8 0.0107 Preproces…
3 15.0 0.000459 roc_auc binary 0.960 8 0.00805 Preproces…
4 0.217 0.00530 roc_auc binary 0.883 8 0.0195 Preproces…
5 0.0382 0.00000190 roc_auc binary 0.793 8 0.0954 Preproces…
collect_predictions(svm_results) %>%
group_by(id) %>%
roc_curve(sex, .pred_female) %>%
ggplot(aes(1-specificity, sensitivity, col = id)) +
geom_abline(lty = 2, col = "grey50", size = 1.5) +
geom_path(show.legend = F, alpha = 0.6, size = 1.2) +
coord_equal() +
theme_classic()
# select best parameter
best_param <- svm_results %>%
select_best("roc_auc")
best_param
# A tibble: 1 × 3
cost rbf_sigma .config
<dbl> <dbl> <chr>
1 1.33 0.325 Preprocessor1_Model03
# update workflow object with values from select best
svm_final_workflow <- svm_workflow %>%
finalize_workflow(best_param)
svm_final_workflow
══ Workflow ══════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: svm_rbf()
── Preprocessor ──────────────────────────────────────────────────────
0 Recipe Steps
── Model ─────────────────────────────────────────────────────────────
Radial Basis Function Support Vector Machine Specification (classification)
Main Arguments:
cost = 1.32561534433381
rbf_sigma = 0.325258400684508
Computational engine: kernlab
svm_fit <- svm_final_workflow %>%
fit(data = penguins_train)
svm_fit
══ Workflow [trained] ════════════════════════════════════════════════
Preprocessor: Recipe
Model: svm_rbf()
── Preprocessor ──────────────────────────────────────────────────────
0 Recipe Steps
── Model ─────────────────────────────────────────────────────────────
Support Vector Machine object of class "ksvm"
SV type: C-svc (classification)
parameter : cost C = 1.32561534433381
Gaussian Radial Basis kernel function.
Hyperparameter : sigma = 0.325258400684508
Number of Support Vectors : 84
Objective Function Value : -79.8079
Training error : 0.056452
Probability model included.
# last fit on training dataset and evaluate on test dataset
final_fit <-
svm_final_workflow %>%
last_fit(penguins_split)
final_fit %>%
collect_metrics() # 0.906 : accuracy, #0.970 roc_auc
# A tibble: 2 × 4
.metric .estimator .estimate .config
<chr> <chr> <dbl> <chr>
1 accuracy binary 0.906 Preprocessor1_Model1
2 roc_auc binary 0.970 Preprocessor1_Model1
For attribution, please cite this work as
lruolin (2021, Nov. 24). pRactice corner: Support Vector Machine with Palmer Penguins. Retrieved from https://lruolin.github.io/myBlog/posts/20211124 - SVM with palmer penguins/
BibTeX citation
@misc{lruolin2021support, author = {lruolin, }, title = {pRactice corner: Support Vector Machine with Palmer Penguins}, url = {https://lruolin.github.io/myBlog/posts/20211124 - SVM with palmer penguins/}, year = {2021} }